import torch

"""
默认把tensor中所有长度为1的维度\轴去掉，对剩下的轴重新0123编号
如果指定了轴，若该轴长度为1，则去掉该轴，否则不进行任何操作

什么是长度为1？即该轴上只有一个元素，
例如：shape(2, 3, 1)中的2轴
tensor([[[ 0.6271],
         [-0.0235],
         [-0.5271]],

        [[ 0.0501],
         [-1.1759],
         [ 1.4502]]])
"""

a = torch.randn(2, 3, 4)
print(a)
print("=="* 50)


torch.manual_seed(0)
x = torch.randint(0, 10, (2, 1, 1, 4))
print(x)
print("----------")


y = x.squeeze()
print(y)
print("----------")

y = x.squeeze(0)
print(y)
print("----------")

y = x.squeeze(1)
print(y)
print("----------")

y = x.squeeze(2)
print(y)
print("----------")

y = x.squeeze(3)
print(y)
print("----------")

y = x.squeeze(-1)
print(y)
print("----------")


y = x.squeeze(-2)
print(y)
print("----------")
